{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autograd & Modules\n",
    "\n",
    "## Back to basics\n",
    "\n",
    "Training loop over CIFAR10 (50,000 train images, 10,000 test images). What happens if you\n",
    "- Remove the `ReLU()`? \n",
    "- Increase the learning rate?\n",
    "- Stack more layers? \n",
    "- Don't normalize the input?\n",
    "- Perform more epochs?\n",
    "\n",
    "Can you completely overfit the training set (i.e. reach 100% training accuracy)?\n",
    "\n",
    "This code is not modular at all. Can you create a function for each specific task and launch the training from the command line, e.g. something like\n",
    "`python main.py --batch_size 64 --n_epochs 10 --lr 0.01`?\n",
    "(hint: see [this](https://github.com/pytorch/examples/blob/master/mnist/main.py))\n",
    "\n",
    "Your training went well. Good. Why not save the weights of the network (`net.state_dict()`) using `torch.save()`?\n",
    "\n",
    "You have a GPU (remote or local). Where do you put the magic `.cuda()` to switch the training to a GPU? Is it faster?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Epoch 0\n",
      "Train loss 0.0367, Train accuracy 24.38%\n",
      "Train loss 0.0300, Train accuracy 33.84%\n",
      "Train loss 0.0291, Train accuracy 36.24%\n",
      "Train loss 0.0287, Train accuracy 37.26%\n",
      "Train loss 0.0282, Train accuracy 38.20%\n",
      "Train loss 0.0280, Train accuracy 38.71%\n",
      "Train loss 0.0278, Train accuracy 39.22%\n",
      "Train loss 0.0278, Train accuracy 39.53%\n",
      "Epoch 1\n",
      "Train loss 0.0282, Train accuracy 47.66%\n",
      "Train loss 0.0267, Train accuracy 43.91%\n",
      "Train loss 0.0265, Train accuracy 44.27%\n",
      "Train loss 0.0265, Train accuracy 44.32%\n",
      "Train loss 0.0262, Train accuracy 44.74%\n",
      "Train loss 0.0262, Train accuracy 44.92%\n",
      "Train loss 0.0261, Train accuracy 45.09%\n",
      "Train loss 0.0262, Train accuracy 45.14%\n",
      "Epoch 2\n",
      "Train loss 0.0269, Train accuracy 55.16%\n",
      "Train loss 0.0263, Train accuracy 47.05%\n",
      "Train loss 0.0259, Train accuracy 47.57%\n",
      "Train loss 0.0259, Train accuracy 47.78%\n",
      "Train loss 0.0256, Train accuracy 47.99%\n",
      "Train loss 0.0255, Train accuracy 48.15%\n",
      "Train loss 0.0255, Train accuracy 48.20%\n",
      "Train loss 0.0255, Train accuracy 48.12%\n",
      "End of training.\n",
      "\n",
      "End of testing. Test accuracy 45.39%\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as t\n",
    "\n",
    "# Network: a two-layer MLP on flattened 3x32x32 images with 10 output classes\n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n",
    "\n",
    "# Data pipeline: image -> tensor -> per-channel normalization -> flat vector\n",
    "to_tensor = t.ToTensor()\n",
    "normalize = t.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
    "flatten = t.Lambda(lambda x: x.view(-1))\n",
    "\n",
    "transform_list = t.Compose([to_tensor, normalize, flatten])\n",
    "train_set = torchvision.datasets.CIFAR10(root='.', train=True, transform=transform_list, download=True)\n",
    "test_set = torchvision.datasets.CIFAR10(root='.', train=False, transform=transform_list, download=True)\n",
    "\n",
    "# Reshuffle the training set at every epoch so that SGD does not see the\n",
    "# batches in the same fixed order; the test order does not matter.\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)\n",
    "\n",
    "# === Train === ###\n",
    "net.train()\n",
    "\n",
    "# train loop\n",
    "for epoch in range(3):\n",
    "    # running totals over the samples seen so far in this epoch\n",
    "    train_correct = 0\n",
    "    train_loss = 0.0\n",
    "    train_seen = 0\n",
    "    print('Epoch {}'.format(epoch))\n",
    "\n",
    "    # loop per epoch\n",
    "    for i, (batch, targets) in enumerate(train_loader):\n",
    "\n",
    "        output = net(batch)\n",
    "        loss = criterion(output, targets)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Bookkeeping uses plain Python numbers: accumulating the `loss` tensor\n",
    "        # itself would keep the autograd graph of every batch alive.\n",
    "        n_samples = targets.size(0)\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        train_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
    "        # the criterion returns the batch mean, so weight it by the batch size\n",
    "        train_loss += loss.item() * n_samples\n",
    "        train_seen += n_samples\n",
    "\n",
    "        # report the per-sample averages over everything seen so far\n",
    "        if i % 100 == 10:\n",
    "            print('Train loss {:.4f}, Train accuracy {:.2f}%'.format(\n",
    "                train_loss / train_seen, 100 * train_correct / train_seen))\n",
    "\n",
    "print('End of training.\\n')\n",
    "    \n",
    "# === Test === ###\n",
    "test_correct = 0\n",
    "net.eval()\n",
    "\n",
    "# loop over the whole test set; no gradients are needed for evaluation\n",
    "with torch.no_grad():\n",
    "    for batch, targets in test_loader:\n",
    "        output = net(batch)\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        test_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
    "\n",
    "# Divide by the real number of test samples: the last batch can be smaller\n",
    "# than 64, so len(test_loader) * 64 would overestimate the denominator.\n",
    "print('End of testing. Test accuracy {:.2f}%'.format(\n",
    "    100 * test_correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Autograd handles almost every basic tensor operation you can think of!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
       "        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
       "        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
       "        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# don't run the cell before you have guessed the answer!\n",
    "# zero-initialized leaf tensor (torch.Tensor(4, 10) would be uninitialized memory)\n",
    "x = torch.zeros(4, 10, requires_grad=True)\n",
    "loss = x[:, :4].sum()\n",
    "loss.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4., 4., 4.],\n",
       "        [6., 6., 6.]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# don't run the cell before you have guessed the answer!\n",
    "# zero-initialized leaf tensor (torch.Tensor(2, 3) would be uninitialized memory)\n",
    "x = torch.zeros(2, 3, requires_grad=True)\n",
    "y = torch.tensor([[1., 2.], [3., 4.]])\n",
    "loss = y.mm(x).sum()\n",
    "loss.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2., 0., 0., 0.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# don't run the cell before you have guessed the answer!\n",
    "x = torch.tensor([[1., 2., 3., 4.]], requires_grad=True)\n",
    "y = x * 2\n",
    "# `y` is not a scalar, so backward() is given the upstream gradient explicitly\n",
    "upstream = torch.tensor([[1., 0., 0., 0.]])\n",
    "y.backward(upstream)\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create your own module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n",
      "tensor([[5, 2, 3, 8, 7, 6, 0, 9, 1, 4]])\n"
     ]
    }
   ],
   "source": [
    "from torch.nn import Parameter\n",
    "\n",
    "class Permutation(nn.Module):\n",
    "    \"\"\"Apply a fixed random permutation along one axis of the input.\n",
    "\n",
    "    Args:\n",
    "        input_features: size of the permuted axis.\n",
    "        axis: dimension whose entries are shuffled (default: 1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_features, axis=1):\n",
    "        super(Permutation, self).__init__()\n",
    "        self.input_features = input_features\n",
    "        self.axis = axis\n",
    "        # the permutation is drawn once here and never trained\n",
    "        self.perm = Parameter(torch.randperm(self.input_features), requires_grad=False)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Return `input` with its entries along `self.axis` reordered by `self.perm`.\"\"\"\n",
    "        return input.index_select(self.axis, self.perm)\n",
    "\n",
    "    def __repr__(self):\n",
    "        # Only attributes that exist are reported: the module defines no\n",
    "        # `output_features` (a permutation keeps the size unchanged).\n",
    "        return '{}({}, axis={})'.format(\n",
    "            self.__class__.__name__, self.input_features, self.axis)\n",
    "            \n",
    "# shuffle the entries of the single row [0, ..., 9]\n",
    "net = Permutation(10)\n",
    "x = torch.arange(10).unsqueeze(0)\n",
    "y = net(x)\n",
    "for tensor in (x, y):\n",
    "    print(tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('perm', tensor([5, 2, 3, 8, 7, 6, 0, 9, 1, 4]))])\n"
     ]
    }
   ],
   "source": [
    "# collect everything the module would save, then show it\n",
    "perm_state = net.state_dict()\n",
    "print(perm_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hooks\n",
    "\n",
    "Forward/backward hooks let you capture the activations/backpropagated gradients! First, print the norm of the activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build a fully-connected network\n",
    "def deep_net(n_layers, features=1000):\n",
    "    \"\"\"Stack `n_layers` (Linear -> ReLU) pairs, each of width `features`.\"\"\"\n",
    "    pairs = [(nn.Linear(features, features), nn.ReLU()) for _ in range(n_layers)]\n",
    "    modules = [module for pair in pairs for module in pair]\n",
    "    return nn.Sequential(*modules)\n",
    "\n",
    "# print activations norms\n",
    "def forward_hook(module, input, output):\n",
    "    \"\"\"Forward hook: print the class name of `module` and the norm of its output.\"\"\"\n",
    "    layer_name = type(module).__name__\n",
    "    activation_norm = output.norm().item()\n",
    "    print('Layer {}\\t Activations norm: {:.4f}'.format(layer_name, activation_norm))\n",
    "\n",
    "# register hook for every layer\n",
    "def register_forward_hooks(net, forward_hook):\n",
    "    \"\"\"Attach `forward_hook` to every direct child of `net`.\n",
    "\n",
    "    Returns the list of hook handles (one per child) so that the hooks can be\n",
    "    detached again with `handle.remove()`; callers are free to ignore it.\n",
    "    \"\"\"\n",
    "    return [layer.register_forward_hook(forward_hook) for layer in net.children()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer Linear\t Activations norm: 412.4592\n",
      "Layer ReLU\t Activations norm: 291.7160\n",
      "Layer Linear\t Activations norm: 170.7109\n",
      "Layer ReLU\t Activations norm: 122.4180\n",
      "Layer Linear\t Activations norm: 71.3020\n",
      "Layer ReLU\t Activations norm: 48.7552\n",
      "Layer Linear\t Activations norm: 31.0433\n",
      "Layer ReLU\t Activations norm: 21.3409\n",
      "Layer Linear\t Activations norm: 17.6968\n",
      "Layer ReLU\t Activations norm: 12.1569\n",
      "Layer Linear\t Activations norm: 15.1106\n",
      "Layer ReLU\t Activations norm: 10.9613\n",
      "Layer Linear\t Activations norm: 14.2940\n",
      "Layer ReLU\t Activations norm: 10.0419\n",
      "Layer Linear\t Activations norm: 14.0194\n",
      "Layer ReLU\t Activations norm: 9.7039\n",
      "Layer Linear\t Activations norm: 14.0066\n",
      "Layer ReLU\t Activations norm: 9.7330\n",
      "Layer Linear\t Activations norm: 14.4152\n",
      "Layer ReLU\t Activations norm: 9.9619\n"
     ]
    }
   ],
   "source": [
    "# a 10-layer network in which every layer reports the norm of its output\n",
    "net = deep_net(10)\n",
    "register_forward_hooks(net, forward_hook)\n",
    "\n",
    "# a single forward pass on a Gaussian batch triggers all the hooks\n",
    "batch_size, features = 512, 1000\n",
    "x = torch.randn(batch_size, features)\n",
    "y = net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Is this network well initialized? Why?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, print the size of the backpropagated gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print sizes\n",
    "def backward_hook(module, grad_input, grad_output):\n",
    "    \"\"\"Backward hook: print the class name of `module` and the size of `grad_output[0]`.\"\"\"\n",
    "    layer_name = type(module).__name__\n",
    "    grad_size = grad_output[0].size()\n",
    "    print('Layer {}\\t Backpropagated gradient size: {}'.format(layer_name, grad_size))\n",
    "\n",
    "# register hook for every layer\n",
    "def register_backward_hooks(net, backward_hook):\n",
    "    \"\"\"Attach `backward_hook` to every direct child of `net`.\n",
    "\n",
    "    `register_full_backward_hook` is used when the installed PyTorch provides\n",
    "    it (the older `register_backward_hook` is deprecated there); otherwise the\n",
    "    legacy method is used. Returns the list of hook handles so that the hooks\n",
    "    can be removed with `handle.remove()`; callers are free to ignore it.\n",
    "    \"\"\"\n",
    "    handles = []\n",
    "    for layer in net.children():\n",
    "        if hasattr(layer, 'register_full_backward_hook'):\n",
    "            handle = layer.register_full_backward_hook(backward_hook)\n",
    "        else:\n",
    "            handle = layer.register_backward_hook(backward_hook)\n",
    "        handles.append(handle)\n",
    "    return handles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer ReLU\t Backpropagated gradient size: torch.Size([512, 1000])\n",
      "Layer Linear\t Backpropagated gradient size: torch.Size([512, 1000])\n"
     ]
    }
   ],
   "source": [
    "# a fresh 10-layer network whose layers report the gradients flowing back\n",
    "net = deep_net(10)\n",
    "register_backward_hooks(net, backward_hook)\n",
    "\n",
    "# reduce the output to a scalar so that backward() needs no argument\n",
    "x = torch.rand(512, 1000)\n",
    "activations = net(x)\n",
    "y = activations.sum()\n",
    "y.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The backpropagated gradients are not to be confused with the gradients w.r.t. the weights! Can you recover the gradients w.r.t. the weights from the backpropagated gradients and the input of the layer?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "How to recover this? tensor([[4.],\n",
      "        [6.]])\n",
      "Weights of second layer Parameter containing:\n",
      "tensor([[1., 2.],\n",
      "        [3., 4.]], requires_grad=True)\n",
      "Backpropagated gradients Parameter containing:\n",
      "tensor([[1., 2.],\n",
      "        [3., 4.]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "# the second layer is declared 2 -> 2 so that it matches the 2x2 weights set below\n",
    "net = nn.Sequential(nn.Linear(1, 2, bias=False), nn.Linear(2, 2, bias=False))\n",
    "with torch.no_grad():\n",
    "    net[1].weight.copy_(torch.tensor([[1., 2.], [3., 4.]]))\n",
    "x = torch.ones(1, 1)\n",
    "\n",
    "# Run the two layers by hand so that the gradient flowing back into the\n",
    "# output of the first layer can be kept (intermediate tensors drop their\n",
    "# .grad unless retain_grad() is called).\n",
    "hidden = net[0](x)\n",
    "hidden.retain_grad()\n",
    "loss = net[1](hidden).sum()\n",
    "loss.backward()\n",
    "print('How to recover this?', net[0].weight.grad)\n",
    "print('Weights of second layer', net[1].weight)\n",
    "print('Backpropagated gradients', hidden.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Buffers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Normalize(nn.Module):\n",
    "    \"\"\"Subtract a per-feature mean from the input.\n",
    "\n",
    "    In training mode the mean of the current batch (over dim 0) is subtracted\n",
    "    and an exponential moving average of it is stored in the `running_mean`\n",
    "    buffer. In eval mode the stored `running_mean` is subtracted instead.\n",
    "\n",
    "    Args:\n",
    "        num_features: length of the `running_mean` buffer.\n",
    "        momentum: weight of the current batch mean in the moving average.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, momentum=0.1):\n",
    "        super(Normalize, self).__init__()\n",
    "        self.num_features = num_features\n",
    "        self.momentum = momentum\n",
    "        # allocated as zeros rather than as uninitialized memory\n",
    "        self.register_buffer('running_mean', torch.zeros(num_features))\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Reset the running mean to zero.\"\"\"\n",
    "        self.running_mean.zero_()\n",
    "\n",
    "    def forward(self, input):\n",
    "        # training mode: use batch statistics and update running statistics\n",
    "        if self.training:\n",
    "            mean = input.mean(dim=0)\n",
    "            # detach(): the running average is updated outside of autograd.\n",
    "            # The scaling is spelled out because the positional\n",
    "            # add_(alpha, tensor) overload is deprecated.\n",
    "            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean.detach())\n",
    "\n",
    "        # eval mode: use running statistics\n",
    "        else:\n",
    "            mean = self.running_mean\n",
    "\n",
    "        # return output\n",
    "        return input - mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "tensor([-0.0505,  0.0084, -0.0412,  0.0265,  0.0038, -0.0156,  0.0816, -0.0134,\n",
      "         0.0900,  0.0227])\n",
      "tensor([-0.0505,  0.0084, -0.0412,  0.0265,  0.0038, -0.0156,  0.0816, -0.0134,\n",
      "         0.0900,  0.0227])\n",
      "tensor([-0.0505,  0.0084, -0.0412,  0.0265,  0.0038, -0.0156,  0.0816, -0.0134,\n",
      "         0.0900,  0.0227])\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(nn.Linear(10, 10), Normalize(10), nn.Linear(10, 2))\n",
    "x = torch.rand(3, 10)\n",
    "\n",
    "# show the buffer before and after a forward pass,\n",
    "# first in training mode, then in eval mode\n",
    "for training in (True, False):\n",
    "    net.train(training)\n",
    "    print(net[1].running_mean)\n",
    "    net(x)\n",
    "    print(net[1].running_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('0.weight', tensor([[ 0.1115, -0.1369,  0.0000, -0.0608, -0.1517,  0.1381, -0.2863, -0.1260,\n",
      "         -0.0241, -0.1595],\n",
      "        [ 0.1688,  0.0803,  0.2164, -0.1989, -0.0444, -0.1110, -0.0143, -0.0228,\n",
      "         -0.2681, -0.2556],\n",
      "        [ 0.1423, -0.0197, -0.0733,  0.1175, -0.1611, -0.2186, -0.0668, -0.2200,\n",
      "          0.2234, -0.0160],\n",
      "        [ 0.2108,  0.0021, -0.0887, -0.1745, -0.1142, -0.1653,  0.0424,  0.0575,\n",
      "          0.0745,  0.2680],\n",
      "        [ 0.0193,  0.2182, -0.2062,  0.1599, -0.1594, -0.1948,  0.1907, -0.1503,\n",
      "         -0.0568,  0.1625],\n",
      "        [ 0.1952, -0.2669,  0.1554, -0.2945, -0.0373,  0.1569, -0.2045, -0.2011,\n",
      "          0.0663, -0.2940],\n",
      "        [ 0.2757,  0.1726, -0.2297,  0.1615,  0.1225,  0.1398,  0.0226,  0.2855,\n",
      "          0.1712, -0.0851],\n",
      "        [ 0.0820, -0.2636,  0.2857, -0.2934,  0.1559,  0.1338, -0.1358, -0.1520,\n",
      "          0.0096, -0.1594],\n",
      "        [ 0.2417,  0.0266,  0.2303,  0.2760,  0.2888, -0.0185,  0.2606,  0.2783,\n",
      "         -0.2460, -0.0503],\n",
      "        [ 0.1312, -0.2788,  0.0943,  0.1774, -0.0821, -0.2727,  0.2744,  0.1967,\n",
      "         -0.0832, -0.1642]])), ('0.bias', tensor([-0.0647,  0.3051, -0.3000,  0.1348, -0.0216,  0.2280,  0.1179,  0.0771,\n",
      "         0.1266,  0.1670])), ('1.running_mean', tensor([-0.0505,  0.0084, -0.0412,  0.0265,  0.0038, -0.0156,  0.0816, -0.0134,\n",
      "         0.0900,  0.0227])), ('2.weight', tensor([[-0.1054,  0.0424, -0.1315, -0.2041, -0.1824, -0.0076, -0.0145, -0.0336,\n",
      "          0.2699,  0.0726],\n",
      "        [ 0.2079, -0.1339, -0.2676,  0.1141, -0.2049,  0.3115, -0.0338,  0.2356,\n",
      "         -0.2201, -0.0324]])), ('2.bias', tensor([-0.2399, -0.1917]))])\n"
     ]
    }
   ],
   "source": [
    "# collect everything the network would save, then show it\n",
    "full_state = net.state_dict()\n",
    "print(full_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradchecking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "# gradcheck compares analytical and numerical gradients, so the input is\n",
    "# created directly as a double-precision leaf tensor that requires grad\n",
    "x = torch.rand(256, 2, dtype=torch.double, requires_grad=True)\n",
    "custom_op = nn.Linear(2, 10).double()\n",
    "res = torch.autograd.gradcheck(custom_op, (x, ))\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd tips and tricks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pointers are everywhere!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.0943,  0.6001],\n",
      "        [ 0.4884, -0.3737]], requires_grad=True)\n",
      "Parameter containing:\n",
      "tensor([[-0.1038,  0.5979],\n",
      "        [ 0.4789, -0.3759]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Linear(2, 2)\n",
    "w = net.weight  # plain assignment, no copy\n",
    "print(w)\n",
    "\n",
    "# one manual gradient step on the weights of the layer\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "step = 0.01 * net.weight.grad\n",
    "net.weight.data -= step # <--- What is this?\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.1687, -0.1027],\n",
      "        [-0.6148,  0.2257]], grad_fn=<CloneBackward>)\n",
      "tensor([[ 0.1687, -0.1027],\n",
      "        [-0.6148,  0.2257]], grad_fn=<CloneBackward>)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Linear(2, 2)\n",
    "w = net.weight.clone()  # explicit copy this time\n",
    "print(w)\n",
    "\n",
    "# one manual gradient step on the weights of the layer\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "step = 0.01 * net.weight.grad\n",
    "net.weight.data -= step # <--- What is this?\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sharing weights "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.7365,  0.7855],\n",
      "        [-0.6110, -0.6640]])\n",
      "tensor([[ 0.7365,  0.7855],\n",
      "        [-0.6110, -0.6640]])\n"
     ]
    }
   ],
   "source": [
    "first, second = nn.Linear(2, 2), nn.Linear(2, 2)\n",
    "first.weight = second.weight  # weight sharing\n",
    "net = nn.Sequential(first, second)\n",
    "\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "# print the weight gradient of each of the two layers\n",
    "for layer in net:\n",
    "    print(layer.weight.grad)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
